import json
import pickle




# Split the pruned dataset into an out-of-distribution (OOD) test set and a
# training set. A sample goes to the OOD test set when none of its
# `pruned_apis` appear among the action names of the APIs in the 'train'
# split; every other sample (by query) goes to the training set. Both outputs
# keep only samples with at least one API and non-empty pruned code.


def _has_apis_and_code(sample):
    """Return True if *sample* lists at least one API and has non-empty pruned code."""
    return len(sample['apis']) > 0 and len(sample['pruned_code']) > 0


# Load the precomputed dataset split (expects a 'train' key listing API keys).
# JSON is UTF-8 by spec, so decode explicitly instead of using the locale default.
with open('dataset_split_keys.json', 'r', encoding='utf-8') as fp:
    dataset_split_keys = json.load(fp)

# Load data
# NOTE(review): `data` is not used anywhere below in this script; kept only so
# the module-level name still exists — confirm nothing else needs it and drop.
with open('./data/query_and_description_2.json', 'r', encoding='utf-8') as fp:
    data = json.load(fp)

# NOTE(review): pickle.load can execute arbitrary code — only load this file
# from a trusted source. The '../data' prefix also differs from the './data'
# prefix used by the other paths here; confirm this is intentional.
with open('../data/statistics.pkl', 'rb') as fp:
    stat = pickle.load(fp)

with open('data/ALL_runed_data.json', 'r', encoding='utf-8') as fp:
    pruned_data = json.load(fp)

# Union of the action names of every API in the train split, with '.' replaced
# by '_' (presumably to match the naming used in `pruned_apis` — confirm).
train_apis = set().union(
    *(stat[api]['action_names'] for api in dataset_split_keys['train'])
)
train_apis = {name.replace('.', '_') for name in train_apis}

# OOD test set: samples sharing no API with the training APIs.
test_data = [
    sample
    for sample in pruned_data
    if train_apis.isdisjoint(sample['pruned_apis']) and _has_apis_and_code(sample)
]

test_data_queries = {sample['query'] for sample in test_data}

print('len(test_data)', len(test_data))

# Training set: everything whose query was not claimed by the test set.
remaining_data = [
    sample
    for sample in pruned_data
    if sample['query'] not in test_data_queries and _has_apis_and_code(sample)
]

with open('./data/ood_test_data.json', 'w', encoding='utf-8') as fp:
    json.dump(test_data, fp)

with open('./data/synthetic_train_data.json', 'w', encoding='utf-8') as fp:
    json.dump(remaining_data, fp)

print('len(remaining_data)', len(remaining_data))
